import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ast
import yaml
import uuid
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler

# Load the task specifications once at import time; the functions below read
# this module-level dict (indexed as task_specs['phase_N'][task_index][...]).
with open("./task_specs", "r") as file:
    # safe_load accepts an open text stream as well as a string.
    task_specs = yaml.safe_load(file)

################# UTILITIES 
def flatten(list_of_lists):
    """Concatenate the items of every sub-iterable into one flat list (one level deep)."""
    flat = []
    for chunk in list_of_lists:
        flat.extend(chunk)
    return flat

def solution_edit_distance(soln1, soln2):
    """Count the positions at which two solutions differ.

    Returns None instead of a count when the two sequences have different
    lengths or when ``soln2`` contains the entry "deck".
    """
    if len(soln1) != len(soln2) or "deck" in soln2:
        return None

    # Element-wise comparison; summing the boolean mask counts the mismatches.
    differs = np.array(soln1) != np.array(soln2)
    return sum(differs)


def solution_distances(row, reference_soln_index):
    """Edit distances between a reference solution and later solutions of a round.

    ``reference_soln_index`` must be "first" or "last":

    * "first": the reference is the solution at
      ``row.ROUNDFEAT_SOLNS_index_first_complete``; distances are computed for
      every solution strictly after it.
    * "last": the reference is the final intermediate solution; distances are
      computed for every solution from the first-complete index onwards
      (the final solution itself included).

    Each solution is represented as a list with one entry per student in the
    phase-2 task spec: the key of the room whose occupant collection contains
    that student, or None when no room contains them. Entries of the returned
    list are whatever ``solution_edit_distance`` returns (a count or None).
    """
    assert(reference_soln_index in {"first", "last"})

    snapshots = row.parsed_intermediateSolutions

    def rooms_by_student(position):
        # For each student (in task-spec order) find the first room key whose
        # occupants include that student.
        return [
            next((room for room, occupants in snapshots[int(position)]['solution'].items()
                  if student in occupants), None)
            for student in task_specs['phase_2'][row.task_index]['students']
        ]

    if reference_soln_index == "first":
        anchor = int(row.ROUNDFEAT_SOLNS_index_first_complete)
        reference = rooms_by_student(anchor)
        positions = range(anchor + 1, len(snapshots))
    elif reference_soln_index == "last":
        reference = rooms_by_student(-1)
        positions = range(int(row.ROUNDFEAT_SOLNS_index_first_complete), len(snapshots))
    else:
        # Only reachable when assertions are disabled; yields an empty result.
        reference = []
        positions = range(0)

    return [solution_edit_distance(reference, rooms_by_student(position))
            for position in positions]



################# DATA PROCESSING
def process_intermediate_solution(intermediate_solution, payoff):
    """Score one intermediate solution.

    The nominal score is the sum of ``payoff[student][int(room)]`` over every
    student placed in a room other than "deck". The real score subtracts 100
    per violated constraint.

    Returns a 6-tuple:
    (nominal_score, real_score, nConstraintsViolated, completeSolution,
     nConstraintsViolated == 0, at).
    """
    nominal_score = 0
    for room, occupants in intermediate_solution['solution'].items():
        if room == "deck":
            # Students on the deck do not contribute to the score.
            continue
        for student in occupants:
            nominal_score += payoff[student][int(room)]

    violations = intermediate_solution['nConstraintsViolated']
    real_score = nominal_score - 100 * violations

    return (
        nominal_score,
        real_score,
        violations,
        intermediate_solution['completeSolution'],
        violations == 0,
        intermediate_solution['at'],
    )

def generate_round_data(row, phase=2):
    """Build a per-step DataFrame for one round.

    ``phase`` is interpolated into the ``'phase_{}'`` key of ``task_specs``
    (default 2). Every entry of ``row.parsed_intermediateSolutions`` is scored
    with ``process_intermediate_solution`` and becomes one row with columns
    nominal_score, real_score, num_violations, is_complete, is_feasible and
    timestamp. A ``subject_id`` column holds the ``subjectId`` of each
    "movedStudent" log entry when their count equals the number of steps,
    otherwise None for every row.
    """
    payoff_table = task_specs[f'phase_{phase}'][row.task_index]['payoff']
    step_records = [process_intermediate_solution(step, payoff_table)
                    for step in row.parsed_intermediateSolutions]
    df_round_data = pd.DataFrame(
        data=np.array(step_records),
        columns=["nominal_score", "real_score", "num_violations",
                 "is_complete", "is_feasible", "timestamp"],
    )

    # Data types
    # NOTE(review): the values are presumably strings at this point because
    # np.array coerces the mixed-type records to a single dtype; hence the
    # explicit casts and the literal_eval on the boolean columns -- confirm.
    for numeric_col in ('nominal_score', 'real_score', 'num_violations'):
        df_round_data[numeric_col] = df_round_data[numeric_col].astype(int)

    for flag_col in ('is_complete', 'is_feasible'):
        df_round_data[flag_col] = [ast.literal_eval(text) for text in df_round_data[flag_col]]

    df_round_data['timestamp'] = pd.to_datetime(df_round_data['timestamp'])

    movers = np.array([entry['subjectId'] for entry in yaml.safe_load(row.log)
                       if entry['verb'] == "movedStudent"])

    # Only attribute steps to subjects when there is exactly one move per step.
    df_round_data['subject_id'] = movers if len(movers) == len(df_round_data) else None

    return df_round_data

def round_features(df_round_data, parsed_log, task_index, task_specs):
    """Add per-step delta columns to a round's data and summarise the round.

    Side effect: ``df_round_data`` is modified in place; the columns
    delta_violations, delta_nominal, delta_real, nominal_gained_per_violation
    and real_gained_per_violation are added. The first row's deltas are the
    row's own values (there is no previous step to diff against).

    Parameters
    ----------
    df_round_data : DataFrame with num_violations, nominal_score, real_score
        and is_complete columns, and an index that contains the label 0.
    parsed_log, task_index, task_specs : not used by this function; kept so
        existing callers keep working.

    Returns
    -------
    dict with keys
        ROUNDFEAT_SOLNS_num_inter_soln: number of rows,
        ROUNDFEAT_SOLNS_index_first_complete: index label of the first complete
            row, or None when the round has no complete solution,
        ROUNDFEAT_SCORES_highest_complete_score: max real_score among complete
            rows (among all rows when none is complete),
        ROUNDFEAT_SCORES_bool_submitted_highest_complete_score: whether the
            last row's real_score equals that maximum.
    """
    # Data prep
    df_round_data['delta_violations'] = df_round_data['num_violations'].diff()
    df_round_data.loc[0, 'delta_violations'] = df_round_data.loc[0, 'num_violations']
    df_round_data['delta_nominal'] = df_round_data['nominal_score'].diff().fillna(df_round_data['nominal_score'])
    df_round_data['delta_real'] = df_round_data['real_score'].diff().fillna(df_round_data['real_score'])
    df_round_data['nominal_gained_per_violation'] = df_round_data['delta_nominal'] / df_round_data['delta_violations']
    df_round_data['real_gained_per_violation'] = df_round_data['delta_real'] / df_round_data['delta_violations']

    # Basic features
    num_inter_soln = len(df_round_data)

    # Some rounds (e.g. BbPxDKJjA52xHhYx2) never reach a complete solution, so
    # test for that explicitly instead of catching every exception.
    complete_rows = df_round_data.query("is_complete == 1")
    if complete_rows.empty:
        index_first_complete = None
        highest_complete_score = df_round_data['real_score'].max()
    else:
        index_first_complete = complete_rows.index[0]
        highest_complete_score = complete_rows['real_score'].max()

    bool_submitted_highest_complete_score = highest_complete_score == df_round_data.real_score.values[-1]

    round_summary = {"ROUNDFEAT_SOLNS_num_inter_soln": num_inter_soln,
                     "ROUNDFEAT_SOLNS_index_first_complete": index_first_complete,
                     "ROUNDFEAT_SCORES_highest_complete_score": highest_complete_score,
                     "ROUNDFEAT_SCORES_bool_submitted_highest_complete_score": bool_submitted_highest_complete_score}

    return round_summary

def standardize_data(df, standardized_columns, base_categories={'real-group', 'solo'}):
    """Add ``zscore_<column>`` columns to ``df`` in place, one complexity level at a time.

    For every category of ``df['complexity_cat']`` and every column in
    ``standardized_columns``, a StandardScaler is fitted on the rows whose
    ``group_formation`` is in ``base_categories`` and whose ``complexity``
    equals that category, then applied to all rows of that complexity level.
    Returns None.
    """
    for complexity in df['complexity_cat'].cat.categories:
        in_level = df['complexity_cat'] == complexity
        for column in standardized_columns:
            # The query string refers to the local names base_categories and
            # complexity through the @ prefix.
            reference_rows = df.query("group_formation in @base_categories and complexity == @complexity")
            scaler = StandardScaler()
            scaler.fit(np.array(reference_rows[column]).reshape(-1, 1))

            level_values = np.array(df.loc[in_level, column]).reshape(-1, 1)
            df.loc[in_level, f"zscore_{column}"] = scaler.transform(level_values)

def base_processing(phase2_filepath):
    """Load the phase-2 rounds CSV and attach parsed logs, round data and round summaries.

    Steps: read the CSV, add ordered categorical columns, parse the ``log``
    and ``intermediateSolutions`` columns with ``ast.literal_eval``, build a
    per-round DataFrame with ``generate_round_data``, flag rounds whose
    round data has no subject ids as ``corrupted_data``, and merge the
    ``round_features`` summaries back on ``round_id``.

    Returns the enriched DataFrame; asserts the merge kept the row count.
    """
    def ordered(values, levels):
        return pd.Categorical(values, categories=levels, ordered=True)

    print("Base processing...")
    df = pd.read_csv(phase2_filepath)
    df['complexity_cat'] = ordered(df['complexity'], ['Very low', 'Low', 'Moderate', 'High', 'Very high'])
    df['block_cat'] = ordered(df['block'], ['ll', 'lh', 'ml', 'mh', 'hl', 'hh'])
    # The block code is two characters: first feeds skill_cat, second social_cat.
    df['skill_cat'] = ordered([code[0] for code in df['block']], ['l', 'm', 'h'])
    df['social_cat'] = ordered([code[1] for code in df['block']], ['l', 'h'])

    # Generate dicts of log and solutions
    print("Parsing logs...")
    df['parsed_log'] = df['log'].apply(ast.literal_eval)
    df['parsed_intermediateSolutions'] = df['intermediateSolutions'].apply(ast.literal_eval)

    # Generate dataframes of round data
    print("Parsing round data...")
    df['round_data'] = df.apply(generate_round_data, axis=1)
    # generate_round_data sets subject_id to None when moves and steps do not line up.
    lacks_subject = [rd.subject_id.values[0] is None for rd in df.round_data.values]
    df.loc[lacks_subject, "corrupted_data"] = True
    df.loc[[not flag for flag in lacks_subject], "corrupted_data"] = False

    # Generate features from round data
    round_summaries = {}
    for row in tqdm(df.itertuples()):
        try:
            round_summaries[row.round_id] = round_features(row.round_data, row.parsed_log, row.task_index, task_specs)
        except BaseException:
            # Report which round failed, then let the error propagate unchanged.
            print(row.round_id)
            raise

    df_round_summaries = pd.DataFrame(list(round_summaries.values()))
    df_round_summaries['round_id'] = list(round_summaries.keys())

    rows_before_merge = len(df)
    df = df.merge(df_round_summaries, on="round_id", how="inner")
    assert rows_before_merge == len(df)

    return df

def _best_solution_index(round_data, complete_only):
    """Return the index label of a round's best step.

    Steps are ranked by real_score descending, ties broken by the smallest
    index label. With ``complete_only`` only rows matching
    ``is_complete == 1`` are considered.
    """
    candidates = round_data.query("is_complete == 1") if complete_only else round_data
    ranked = candidates.rename_axis('index').sort_values(by=["real_score", "index"], ascending=[False, True])
    return ranked.index[0]


def _max_valid_distance(distances):
    """Return the largest non-None distance, or None when there is none."""
    valid = [d for d in distances if d is not None]
    # np.max raises ValueError on an empty sequence, so test emptiness AFTER
    # dropping the None entries rather than on the raw list.
    return np.max(valid) if valid else None


def additional_features(df):
    """Add timing and solution-distance features to the rounds DataFrame.

    ``df`` is modified in place and also returned. Reads the round_data,
    round_id, startTimeAt, round_duration, optimal_solution and ROUNDFEAT_*
    columns, plus whatever ``solution_distances`` needs from each row.

    Rounds with no complete solution (null
    ROUNDFEAT_SOLNS_index_first_complete) get that index set to their last
    step, and their "best" solution is picked among all steps; every other
    round picks its best solution among complete steps only.

    Duration columns are expressed in minutes.
    NOTE(review): the conversion divides integer timedeltas by 60 * 10**9,
    i.e. it assumes nanosecond resolution -- confirm for the pandas version
    in use.
    """
    print("Processing time features...")
    # For rounds that have no complete solution, treat the last solution as
    # the "first complete" one.
    missing_first_complete = df['ROUNDFEAT_SOLNS_index_first_complete'].isnull()
    incomplete_rounds = set(df.loc[missing_first_complete, 'round_id'].values)
    df.loc[missing_first_complete, 'ROUNDFEAT_SOLNS_index_first_complete'] = (df.loc[missing_first_complete, 'ROUNDFEAT_SOLNS_num_inter_soln'] - 1).astype(int)

    # Row masks, computed once instead of once per assignment.
    is_incomplete = [x in incomplete_rounds for x in df['round_id']]
    has_complete = [not flag for flag in is_incomplete]

    df.loc[is_incomplete, 'best_solution_timestamp'] = df.loc[is_incomplete].apply(lambda x: x.round_data.timestamp[_best_solution_index(x.round_data, complete_only=False)], axis=1)
    df.loc[is_incomplete, 'best_solution_index'] = df.loc[is_incomplete].apply(lambda x: _best_solution_index(x.round_data, complete_only=False), axis=1)

    df.loc[has_complete, 'best_solution_timestamp'] = df.loc[has_complete].apply(lambda x: x.round_data.timestamp[_best_solution_index(x.round_data, complete_only=True)], axis=1)
    df.loc[has_complete, 'best_solution_index'] = df.loc[has_complete].apply(lambda x: _best_solution_index(x.round_data, complete_only=True), axis=1)

    df['round_start_timestamp'] = pd.to_datetime(df['startTimeAt'], utc=True)
    df['first_step_timestamp'] = df.apply(lambda x: x.round_data.timestamp[0], axis=1)
    df['best_solution_timestamp'] = pd.to_datetime(df['best_solution_timestamp'])
    df['first_complete_timestamp'] = df.apply(lambda x: x.round_data.timestamp[x.ROUNDFEAT_SOLNS_index_first_complete], axis=1)
    df['final_solution_timestamp'] = df.apply(lambda x: x.round_data.timestamp[x.ROUNDFEAT_SOLNS_num_inter_soln-1], axis=1)

    def minutes_between(later, earlier):
        # Integer timedelta (nanoseconds assumed) -> minutes.
        return (later - earlier).astype(int) / (60 * 10**9)

    df['time_to_first_step'] = minutes_between(df['first_step_timestamp'], df['round_start_timestamp'])
    df['time_from_first_step_to_first_complete'] = minutes_between(df['first_complete_timestamp'], df['first_step_timestamp'])
    df['time_from_first_complete_to_final'] = minutes_between(df['final_solution_timestamp'], df['first_complete_timestamp'])
    # Minutes from round start to the final intermediate solution.
    time_to_final = (df['time_to_first_step'] + df['time_from_first_step_to_first_complete'] + df['time_from_first_complete_to_final'])
    df['time_from_final_to_submit'] = df['round_duration'] - time_to_final
    df['time_to_best_solution'] = minutes_between(df['best_solution_timestamp'], df['round_start_timestamp'])
    df['time_from_first_step_to_best'] = df['time_to_best_solution'] - df['time_to_first_step']
    df['time_from_best_to_final'] = time_to_final - df['time_to_best_solution']

    # Solution-related features
    print("Processing solution features...")
    df['intermediate_solution_pace'] = df['ROUNDFEAT_SOLNS_num_inter_soln'] / df['round_duration']
    df['first_complete_is_final'] = df['first_complete_timestamp'] == df['final_solution_timestamp']

    df['solution_distances_from_first'] = df.apply(lambda x: solution_distances(x, "first"), axis=1)
    df['solution_distances_from_last'] = df.apply(lambda x: solution_distances(x, "last"), axis=1)

    # Distances can be None (see solution_edit_distance); a round whose
    # distances are all None is treated like a round with no distances.
    df['max_soln_dist'] = [_max_valid_distance(x) for x in df['solution_distances_from_first']]
    df['max_soln_dist_fillna'] = df['max_soln_dist'].fillna(0)

    df['pct_solns_dist2_from_last'] = [np.mean(np.array([y for y in x if y is not None]) <= 2) for x in df['solution_distances_from_last']]

    df['normalized_best_score'] = df['ROUNDFEAT_SCORES_highest_complete_score'] / df['optimal_solution']

    return df


def generate_nominal_groups(df_rounds, players_filepath, phase1_filepath, n_nominal_groups=5000):
    """Build synthetic "nominal group" rounds by sampling solo players in threes.

    For each nominal definition (best by phase-1 efficiency, shortest phase-1
    duration, highest skill_CSOP, or random) and each block, the eligible solo
    players of that block are repeatedly shuffled and cut into groups of
    three. One representative per group is kept (the first after sorting by
    the definition's column, or simply the first in shuffled order for the
    random definition) and that player's rows in ``df_rounds`` are copied with
    new game_id / round_id values and nominal-group labels.

    Parameters
    ----------
    df_rounds : rounds DataFrame; read for workerIds, round_id, block and
        group_formation.
    players_filepath : CSV of players, indexed by workerId.
    phase1_filepath : CSV of phase-1 rounds; its workerIds column is renamed
        to workerId.
    n_nominal_groups : total number of nominal groups per definition, split
        across blocks in proportion to the real-group rounds of each block.

    Returns
    -------
    DataFrame concatenating all sampled rows.

    Raises
    ------
    ValueError
        If a block needs nominal groups but has fewer than three eligible
        solo players (the sampling loop could never make progress).
    """
    df_players = pd.read_csv(players_filepath, index_col='workerId')
    df_phase1 = pd.read_csv(phase1_filepath).rename(columns={'workerIds': 'workerId'})
    df_efficiency_phase1 = df_phase1.query("complexity == 'Moderate'").groupby("workerId")[['normalized_score', 'round_duration']].sum()
    df_efficiency_phase1['phase1_efficiency'] = df_efficiency_phase1['normalized_score'] / df_efficiency_phase1['round_duration']
    # Both frames are indexed by workerId, so these assignments align on it.
    df_players['phase1_efficiency'] = df_efficiency_phase1['phase1_efficiency']
    df_players['phase1_duration'] = df_efficiency_phase1['round_duration']

    df_players_solo = df_players[(df_players.participated_in_step2 == True) &
                                 (df_players.step2_team == False) &
                                 (df_players.player_valid == True) &
                                 (df_players.index.isin(df_rounds.workerIds))]

    # Definition name -> column used to pick the group's representative
    # (None means: keep the shuffled order).
    nominal_definitions = {"nominal-group-efficiency": "phase1_efficiency",
                           "nominal-group-duration": "phase1_duration",
                           "nominal-group-score": "skill_CSOP",
                           "nominal-group-random": None}

    real_block_counts = df_rounds.query("group_formation == 'real-group'")['block'].value_counts()
    nominal_block_sizes = dict((n_nominal_groups * real_block_counts / real_block_counts.sum()).astype(int))

    nominal_samples = []
    # Draw counter: seeds each shuffle and is embedded in the new round ids.
    i = 0

    for (nominal_group_name, nominal_sorting_col) in tqdm(nominal_definitions.items()):
        for block in tqdm(nominal_block_sizes.keys()):
            # The pool does not change between draws, so query it once.
            block_pool = df_players_solo.query("block == @block")
            teams_per_draw = len(block_pool) // 3

            if teams_per_draw == 0 and nominal_block_sizes[block] > 0:
                # Previously this case looped forever: each draw added 0 teams.
                raise ValueError(
                    "Block {!r} needs {} nominal groups but has only {} eligible solo players "
                    "(at least 3 are required)".format(block, nominal_block_sizes[block], len(block_pool))
                )

            n_teams_sampled_in_block = 0
            while n_teams_sampled_in_block < nominal_block_sizes[block]:
                grouped = (block_pool
                           .sample(frac=1, random_state=i)
                           .iloc[:teams_per_draw * 3]
                           .assign(nominal_group_id=flatten([[x] * 3 for x in range(teams_per_draw)])))

                if nominal_sorting_col is not None:
                    # Durations: smallest first; every other column: largest first.
                    grouped = grouped.sort_values(nominal_sorting_col, ascending=(nominal_sorting_col == 'phase1_duration'))

                nominal_worker_ids = set(grouped.drop_duplicates("nominal_group_id", keep="first").index)

                game_id_map = dict(zip(nominal_worker_ids, [uuid.uuid4().hex for x in range(len(nominal_worker_ids))]))
                round_id_map = dict(zip(nominal_worker_ids, [str(np.random.randint(0, 10000)) for x in range(len(nominal_worker_ids))]))

                nominal_samples.append((df_rounds.query("workerIds in @nominal_worker_ids")
                                        .assign(game_id=lambda x: x['workerIds'].map(game_id_map))
                                        .assign(round_id=lambda x: "nominal_" + str(i) + "_" + x['workerIds'].map(round_id_map) + "_" + x['round_id'])
                                        .assign(nominal_group=True)
                                        .assign(group_formation=nominal_group_name)
                                        .assign(nominal_block=block)
                                        ))
                i += 1
                n_teams_sampled_in_block += teams_per_draw

    df_nominal = pd.concat(nominal_samples, ignore_index=True)
    return df_nominal